{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LTj2A0RUQFNo"
   },
   "source": [
    "# Convert our BiRefNet weights to ONNX format.\n",
    "\n",
    "> This Colab notebook is adapted from [Kazuhito00](https://github.com/Kazuhito00)'s nice work.\n",
    "\n",
    "> Repo: https://github.com/Kazuhito00/BiRefNet-ONNX-Sample  \n",
    "> Original Colab: https://colab.research.google.com/github/Kazuhito00/BiRefNet-ONNX-Sample/blob/main/Convert2ONNX.ipynb\n",
    "\n",
    "+ Converting a standard BiRefNet on GPU needs **19.7GB** of GPU memory.\n",
    "+ Currently, Colab with 12.7GB RAM / 15GB GPU memory cannot hold the conversion of BiRefNet in its default setting, so BiRefNet with the `swin_v1_tiny` backbone is used as the example on Colab."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Online Colab version: https://colab.research.google.com/drive/1z6OruR52LOvDDpnp516F-N4EyPGrp5om"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install -q onnx onnxscript onnxruntime-gpu==1.18.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/root/autodl-tmp/BiRefNet\n"
     ]
    }
   ],
   "source": [
    "cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "781JHjLJmveh"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "weights_file = 'BiRefNet-matting-epoch_100.pth'  # https://github.com/ZhengPeng7/BiRefNet/releases/download/v1/BiRefNet-general-bb_swin_v1_tiny-epoch_232.pth\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('config.py') as fp:\n",
    "    file_lines = fp.read()\n",
    "if 'swin_v1_tiny' in weights_file:\n",
    "    print('Set `swin_v1_tiny` as the backbone.')\n",
    "    file_lines = file_lines.replace(\n",
    "        '''\n",
    "            'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5\n",
    "        ][6]\n",
    "        ''',\n",
    "        '''\n",
    "            'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5\n",
    "        ][3]\n",
    "        ''',\n",
    "    )\n",
    "    with open('config.py', mode=\"w\") as fp:\n",
    "        fp.write(file_lines)\n",
    "else:\n",
    "    file_lines = file_lines.replace(\n",
    "        '''\n",
    "            'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5\n",
    "        ][3]\n",
    "        ''',\n",
    "        '''\n",
    "            'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5\n",
    "        ][6]\n",
    "        ''',\n",
    "    )\n",
    "    with open('config.py', mode=\"w\") as fp:\n",
    "        fp.write(file_lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "7lFgKfPS8Icy"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/birefnet/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/root/miniconda3/envs/birefnet/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers\n",
      "  warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.layers\", FutureWarning)\n",
      "/root/miniconda3/envs/birefnet/lib/python3.9/site-packages/timm/models/registry.py:4: FutureWarning: Importing from timm.models.registry is deprecated, please import via timm.models\n",
      "  warnings.warn(f\"Importing from {__name__} is deprecated, please import via timm.models\", FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "from utils import check_state_dict\n",
    "from models.birefnet import BiRefNet\n",
    "\n",
    "\n",
    "birefnet = BiRefNet(bb_pretrained=False)\n",
    "state_dict = torch.load('./{}'.format(weights_file), map_location=device, weights_only=True)\n",
    "state_dict = check_state_dict(state_dict)\n",
    "birefnet.load_state_dict(state_dict)\n",
    "\n",
    "torch.set_float32_matmul_precision(['high', 'highest'][0])\n",
    "\n",
    "birefnet.to(device)\n",
    "birefnet.half()\n",
    "_ = birefnet.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JVgJAdgxQVJW"
   },
   "source": [
    "# Process deform_conv2d in the conversion to ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into 'deform_conv2d_onnx_exporter'...\n",
      "remote: Enumerating objects: 205, done.\u001b[K\n",
      "remote: Counting objects: 100% (7/7), done.\u001b[K\n",
      "remote: Total 205 (delta 6), reused 6 (delta 6), pack-reused 198 (from 1)\u001b[K\n",
      "Receiving objects: 100% (205/205), 36.21 KiB | 170.00 KiB/s, done.\n",
      "Resolving deltas: 100% (102/102), done.\n"
     ]
    }
   ],
   "source": [
    "# Clone the exporter repository, copy its single source module into the current directory, then delete the clone.\n",
    "!git clone https://github.com/masamitsu-murase/deform_conv2d_onnx_exporter\n",
    "%cp deform_conv2d_onnx_exporter/src/deform_conv2d_onnx_exporter.py .\n",
    "!rm -rf deform_conv2d_onnx_exporter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Patch the copied exporter in place: its `return sym_help._get_tensor_dim_size(tensor, dim)` statement\n",
    "# is replaced by code that falls back to the tensor type's strides when that call returns None.\n",
    "with open('deform_conv2d_onnx_exporter.py') as fp:\n",
    "    file_lines = fp.read()\n",
    "\n",
    "# Fallbacks in the injected code: dim 3 -> strides[2], dim 2 -> strides[1] // strides[2], dim 0 -> strides[3].\n",
    "# NOTE(review): presumably relies on a contiguous 4-D layout (and a batch size of 1 for dim 0) - confirm.\n",
    "file_lines = file_lines.replace(\n",
    "    \"return sym_help._get_tensor_dim_size(tensor, dim)\",\n",
    "    '''\n",
    "    tensor_dim_size = sym_help._get_tensor_dim_size(tensor, dim)\n",
    "    if tensor_dim_size == None and (dim == 2 or dim == 3):\n",
    "        import typing\n",
    "        from torch import _C\n",
    "\n",
    "        x_type = typing.cast(_C.TensorType, tensor.type())\n",
    "        x_strides = x_type.strides()\n",
    "\n",
    "        tensor_dim_size = x_strides[2] if dim == 3 else x_strides[1] // x_strides[2]\n",
    "    elif tensor_dim_size == None and (dim == 0):\n",
    "        import typing\n",
    "        from torch import _C\n",
    "\n",
    "        x_type = typing.cast(_C.TensorType, tensor.type())\n",
    "        x_strides = x_type.strides()\n",
    "        tensor_dim_size = x_strides[3]\n",
    "\n",
    "    return tensor_dim_size\n",
    "    ''',\n",
    ")\n",
    "\n",
    "# Write the patched module back so the later `import deform_conv2d_onnx_exporter` picks it up.\n",
    "with open('deform_conv2d_onnx_exporter.py', mode=\"w\") as fp:\n",
    "    fp.write(file_lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "vJiZv0L75kTe"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:441: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if W % self.patch_size[1] != 0:\n",
      "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:443: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if H % self.patch_size[0] != 0:\n",
      "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:379: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  Hp = int(np.ceil(H / self.window_size)) * self.window_size\n",
      "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:380: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  Wp = int(np.ceil(W / self.window_size)) * self.window_size\n",
      "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:216: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  assert L == H * W, \"input feature has wrong size\"\n",
      "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:67: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  B = int(windows.shape[0] / (H * W / window_size / window_size))\n",
      "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:254: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if pad_r > 0 or pad_b > 0:\n",
      "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:287: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  assert L == H * W, \"input feature has wrong size\"\n",
      "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:292: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  pad_input = (H % 2 == 1) or (W % 2 == 1)\n",
      "/root/autodl-tmp/BiRefNet/models/backbones/swin_v1.py:293: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if pad_input:\n",
      "/root/miniconda3/envs/birefnet/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n",
      "verbose: False, log level: Level.ERROR\n",
      "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from torchvision.ops.deform_conv import DeformConv2d\n",
    "import deform_conv2d_onnx_exporter\n",
    "\n",
    "# register deform_conv2d operator\n",
    "deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()\n",
    "\n",
    "def convert_to_onnx(net, file_name='output.onnx', input_shape=(1024, 1024), device=device):\n",
    "    input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device).half()\n",
    "\n",
    "    input_layer_names = ['input_image']\n",
    "    output_layer_names = ['output_image']\n",
    "\n",
    "    torch.onnx.export(\n",
    "        net,\n",
    "        input,\n",
    "        file_name,\n",
    "        verbose=False,\n",
    "        opset_version=17,\n",
    "        input_names=input_layer_names,\n",
    "        output_names=output_layer_names,\n",
    "    )\n",
    "convert_to_onnx(birefnet, weights_file.replace('.pth', '.onnx'), input_shape=(1024, 1024), device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-eU-g40P1zS-"
   },
   "source": [
    "# Load the ONNX model and run inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "\n",
    "\n",
    "transform_image = transforms.Compose([\n",
    "    transforms.Resize((1024, 1024)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "imagepath = './Helicopter-HR.jpg'\n",
    "image = Image.open(imagepath)\n",
    "image = image.convert(\"RGB\") if image.mode != \"RGB\" else image\n",
    "input_images = transform_image(image).unsqueeze(0).to(device).half()\n",
    "input_images_numpy = input_images.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rwzdKX1EfYkd"
   },
   "outputs": [],
   "source": [
    "import onnxruntime\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']\n",
    "onnx_session = onnxruntime.InferenceSession(\n",
    "    weights_file.replace('.pth', '.onnx'),\n",
    "    providers=providers\n",
    ")\n",
    "input_name = onnx_session.get_inputs()[0].name\n",
    "print(onnxruntime.get_device(), onnx_session.get_providers())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DJVtxZUZum4-"
   },
   "outputs": [],
   "source": [
    "from time import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "time_st = time()\n",
    "pred_onnx = torch.tensor(\n",
    "    onnx_session.run(None, {input_name: input_images_numpy if device == 'cpu' else input_images_numpy})[-1]\n",
    ").squeeze(0).sigmoid().cpu()\n",
    "print(time() - time_st)\n",
    "\n",
    "plt.imshow(pred_onnx.squeeze(), cmap='gray'); plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    preds = birefnet(input_images)[-1].sigmoid().cpu()\n",
    "plt.imshow(preds.squeeze(), cmap='gray'); plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "diff = abs(preds - pred_onnx)\n",
    "print('sum(diff):', diff.sum())\n",
    "plt.imshow((diff).squeeze(), cmap='gray'); plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qzYHflt92Bjd"
   },
   "source": [
    "# Efficiency Comparison between .pth and .onnx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "A5IYfT-uzphA",
    "outputId": "2999e345-950e-41b3-ddd3-9f58a71a3f21"
   },
   "outputs": [],
   "source": [
    "%%timeit\n",
    "with torch.no_grad():\n",
    "    preds = birefnet(input_images)[-1].sigmoid().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G0Ul4rfNg1za"
   },
   "outputs": [],
   "source": [
    "%%timeit\n",
    "# Times the ONNX Runtime call together with the tensor conversion and the sigmoid.\n",
    "pred_onnx = torch.tensor(\n",
    "    onnx_session.run(None, {input_name: input_images_numpy})[-1]\n",
    ").squeeze(0).sigmoid().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
